{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import lancedb\n",
    "from lancedb.embeddings import get_registry\n",
    "from lancedb.pydantic import LanceModel, Vector\n",
    "\n",
    "from application import (\n",
    "    BURR_DOCS_DIR,\n",
    "    DEVICE,\n",
    "    DOCS_LIMIT,\n",
    "    EMS_MODEL,\n",
    "    LANCE_URI,\n",
    "    add_table,\n",
    "    assistant_message,\n",
    "    evaluator_template,\n",
    "    exa_search_query_template,\n",
    "    get_user_query,\n",
    "    lancedb_query_template,\n",
    "    load_burr_docs,\n",
    "    route_template,\n",
    "    user_message,\n",
    ")\n",
    "from application import application as adaptive_crag_app\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext autoreload"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add Burr Docs to LanceDB "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sentence-transformers model for embeddings\n",
    "lance_model = (\n",
    "    get_registry().get(\"sentence-transformers\").create(name=EMS_MODEL, device=DEVICE)\n",
    ")\n",
    "\n",
    "\n",
    "# define the schema of your data using Pydantic\n",
    "class LanceDoc(LanceModel):\n",
    "    text: str = lance_model.SourceField()\n",
    "    vector: Vector(dim=lance_model.ndims()) = lance_model.VectorField()  # type: ignore\n",
    "    file_name: str\n",
    "\n",
    "\n",
    "lance_db = lancedb.connect(LANCE_URI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Merging 11 docs\n",
      "Merged to 4 docs\n",
      "\n",
      "Merging 47 docs\n",
      "Merged to 12 docs\n",
      "\n",
      "Merging 7 docs\n",
      "Merged to 3 docs\n",
      "\n",
      "Merging 10 docs\n",
      "Merged to 3 docs\n",
      "\n",
      "Merging 34 docs\n",
      "Merged to 9 docs\n",
      "\n"
     ]
    }
   ],
   "source": [
    "burr_docs = load_burr_docs(burr_docs_dir=BURR_DOCS_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "burr_table = add_table(\n",
    "    db=lance_db, table_name=\"burr_docs\", data=burr_docs, schema=LanceDoc\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Hybrid Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['The ~ operator will invert a condition. For instance:\\nfrom burr.core import when, expr\\nwith_transitions(\\n    (\"from\", \"to\", ~when(foo=\"baz\"),  # will evaluate to True when foo != baz\\n    (\"from\", \"to\", ~expr(\\'epochs<=100\\')) # will evaluate to True when epochs hits 101\\n)\\nAll keys present in the condition (E.G. foo and epochs above) must be present in the state for the condition to work. It will error otherwise.\\nNote\\nThe default condition is a special case, and will always evaluate to True. It is useful for defining a “catch-all” transition that will be selected if no other condition is met. If you pass a tuple of length 2 to with_transitions, the default condition will be used.',\n",
       " 'Conditions\\nConditions have a few APIs, but the most common are the three convenience functions:\\nfrom burr.core import when, expr, default\\nwith_transitions(\\n    (\"from\", \"to\", when(foo=\"bar\"),  # will evaluate when the state has the variable \"foo\" set to the value \"bar\"\\n    (\"from\", \"to\", expr(\\'epochs>100\\')) # will evaluate to True when epochs is greater than 100\\n    (\"from\", \"to\", default)  # will always evaluate to True\\n    (\"from\", \"to\") # leaving out a third conditions we allow defaults\\n)\\nConditions are evaluated in the order they are specified, and the first one that evaluates to True will be the transition that is selected when determining which action to run next. If no condition evaluates to True, the application execution will stop early.',\n",
       " 'from burr.core import ApplicationBuilder, default, expr\\napp = (\\n    ApplicationBuilder()\\n    .with_actions(human_input, ai_response)\\n    .with_transitions(\\n        (\"human_input\", \"ai_response\"),\\n        (\"ai_response\", \"human_input\")\\n    ).with_state(chat_history=[])\\n    .with_entrypoint(\"human_input\")\\n    .build()\\n)\\nRunning\\nThere are four APIs for executing an application.\\nstep/astep\\nReturns the tuple of the action, the result of that action, and the new state. Call this if you want to run the application step-by-step.',\n",
       " '@action(reads=[\"var_from_state\"], writes=[\"var_to_update\"])\\ndef custom_action(state: State, increment_by: int = 1) -> State:\\n    result = {\"var_to_update\": state[\"var_from_state\"] + increment_by}\\n    return state.update(var_to_update=state[\"var_from_state\"] + increment_by)\\napp = ApplicationBuilder().with_actions(\\n    custom_action=custom_action\\n)...\\nThis means that the application does not need the inputs to be set.\\nClass-Based Actions\\nYou can define an action by implementing the Action class:\\nfrom burr.core import Action, State',\n",
       " 'action, result, state = application.step()\\nIf you’re in an async context, you can run astep instead:\\naction, result, state = await application.astep()\\nStep can also take in inputs as a dictionary, which will be passed to the action’s run function as keyword arguments. This is specifically meant for a “human in the loop” scenario, where the action needs to ask for input from a user. In this case, the control flow is meant to be interrupted to allow for the user to provide input. See inputs for more information.']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = (\n",
    "    burr_table.search(query=\"expr\", query_type=\"hybrid\")\n",
    "    .limit(limit=DOCS_LIMIT)\n",
    "    .to_pandas()\n",
    ")\n",
    "res.text.tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What are actions in Burr?\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Router"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<task>\n",
      "You are a world-class router for user queries.\n",
      "Given a user query, you may need some extra information for answering the query. So you need to select the best place to find the answer.\n",
      "You have the following options:\n",
      "1. You will have some vectorstore tables with information on specific topics. If the query is directly related to any of these topics, return the relevant table name.\n",
      "2. If you can answer the query directly with your own knowledge, return \"assistant\".\n",
      "3. If you don't have the answer to the query, you can search the internet for the answer. In this case, return \"web_search\".\n",
      "</task>\n",
      "\n",
      "<available_tables>\n",
      "burr_docs\n",
      "</available_tables>\n",
      "\n",
      "<chat_history>\n",
      "user: hello\n",
      "assistant: hi\n",
      "</chat_history>\n",
      "\n",
      "<query>\n",
      "What are actions in Burr?\n",
      "</query>\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    route_template(\n",
    "        table_names=lance_db.table_names(),\n",
    "        query=query,\n",
    "        chat_history=[user_message(\"hello\"), assistant_message(\"hi\")],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rewrite Query\n",
    "Our AI will return an object of this in response:\n",
    "```python\n",
    "class LanceDBQuery(BaseModel):\n",
    "    keywords: list[str] = Field(..., min_length=1, max_length=10)\n",
    "    query: str = Field(..., min_length=10, max_length=300)\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        return \", \".join(self.keywords) + \", \" + self.query\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<task>\n",
      "You are a world-class researcher with access to a vectorstore that can be one of or similar to ChromaDB, Weaviate, LanceDB, VertexAI, etc.\n",
      "Given a user query, rewrite the query to find the most relevant information in the vectorstore.\n",
      "Remember that the vectorstore has hybrid search capability, which means it can do both full-text search and vector similarity search.\n",
      "So make sure to not remove any important keywords from the query. You can even add more keywords if you think they are relevant.\n",
      "Split the query into a list of keywords and a string query.\n",
      "</task>\n",
      "\n",
      "<query>\n",
      "What are actions in Burr?\n",
      "</query>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(lancedb_query_template(query=query))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract Web Search Keywords\n",
    "Our AI will return an object of this in response:\n",
    "```python\n",
    "class ExaSearchKeywords(BaseModel):\n",
    "    keywords: list[str] = Field(..., min_length=2, max_length=5)\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        return \", \".join(self.keywords)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<task>\n",
      "You are a world-class internet researcher.\n",
      "Given a user query, extract 2-5 keywords as queries for the web search. Including the topic background and main intent.\n",
      "</task>\n",
      "\n",
      "<examples>\n",
      "\n",
      "<query>\n",
      "What is Henry Feilden’s occupation?\n",
      "</query>\n",
      "<keywords>\n",
      "Henry Feilden, occupation\n",
      "</keywords>\n",
      "\n",
      "<query>\n",
      "In what city was Billy Carlson born?\n",
      "</query>\n",
      "<keywords>\n",
      "city, Billy Carlson, born\n",
      "</keywords>\n",
      "\n",
      "<query>\n",
      "What is the religion of John Gwynn?\n",
      "</query>\n",
      "<keywords>\n",
      "religion, John Gwynn\n",
      "</keywords>\n",
      "\n",
      "<query>\n",
      "What sport does Kiribati men’s national basketball team play?\n",
      "</query>\n",
      "<keywords>\n",
      "sport, Kiribati men’s national basketball team play\n",
      "</keywords>\n",
      "\n",
      "</examples>\n",
      "\n",
      "<query>\n",
      "What are actions in Burr?\n",
      "</query>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(exa_search_query_template(query=query))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Document Relevance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<task>\n",
      "You are a world-class document relevance evaluator.\n",
      "Given a user query, does the following document have the exact information to answer the question? Answer True or False.\n",
      "</task>\n",
      "\n",
      "<query>\n",
      "What are actions in Burr?\n",
      "</query>\n",
      "\n",
      "<document>\n",
      "Actions are when things go burrr\n",
      "</document>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(evaluator_template(query=query, document=\"Actions are when things go burrr\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"726pt\" height=\"541pt\"\n",
       " viewBox=\"0.00 0.00 725.51 541.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 537)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-537 721.51,-537 721.51,4 -4,4\"/>\n",
       "<!-- router -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>router</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M567,-466C567,-466 529,-466 529,-466 523,-466 517,-460 517,-454 517,-454 517,-442 517,-442 517,-436 523,-430 529,-430 529,-430 567,-430 567,-430 573,-430 579,-436 579,-442 579,-442 579,-454 579,-454 579,-460 573,-466 567,-466\"/>\n",
       "<text text-anchor=\"middle\" x=\"548\" y=\"-444.3\" font-family=\"Times,serif\" font-size=\"14.00\">router</text>\n",
       "</g>\n",
       "<!-- rewrite_query_for_lancedb -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>rewrite_query_for_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M338,-385C338,-385 156,-385 156,-385 150,-385 144,-379 144,-373 144,-373 144,-361 144,-361 144,-355 150,-349 156,-349 156,-349 338,-349 338,-349 344,-349 350,-355 350,-361 350,-361 350,-373 350,-373 350,-379 344,-385 338,-385\"/>\n",
       "<text text-anchor=\"middle\" x=\"247\" y=\"-363.3\" font-family=\"Times,serif\" font-size=\"14.00\">rewrite_query_for_lancedb</text>\n",
       "</g>\n",
       "<!-- router&#45;&gt;rewrite_query_for_lancedb -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>router&#45;&gt;rewrite_query_for_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M516.78,-443.05C481.3,-438.17 421.84,-428.78 372,-415 346.58,-407.97 319.07,-397.91 296.35,-388.9\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"297.51,-385.59 286.92,-385.11 294.9,-392.09 297.51,-385.59\"/>\n",
       "</g>\n",
       "<!-- extract_keywords_for_exa_search -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>extract_keywords_for_exa_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M501.5,-170C501.5,-170 270.5,-170 270.5,-170 264.5,-170 258.5,-164 258.5,-158 258.5,-158 258.5,-146 258.5,-146 258.5,-140 264.5,-134 270.5,-134 270.5,-134 501.5,-134 501.5,-134 507.5,-134 513.5,-140 513.5,-146 513.5,-146 513.5,-158 513.5,-158 513.5,-164 507.5,-170 501.5,-170\"/>\n",
       "<text text-anchor=\"middle\" x=\"386\" y=\"-148.3\" font-family=\"Times,serif\" font-size=\"14.00\">extract_keywords_for_exa_search</text>\n",
       "</g>\n",
       "<!-- router&#45;&gt;extract_keywords_for_exa_search -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>router&#45;&gt;extract_keywords_for_exa_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M516.76,-445.43C469.69,-441.2 386,-425.58 386,-368 386,-368 386,-368 386,-232 386,-214.89 386,-195.7 386,-180.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"389.5,-180.26 386,-170.26 382.5,-180.26 389.5,-180.26\"/>\n",
       "<text text-anchor=\"middle\" x=\"453\" y=\"-296.3\" font-family=\"Times,serif\" font-size=\"14.00\">route=web_search</text>\n",
       "</g>\n",
       "<!-- ask_assistant -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>ask_assistant</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M581,-36C581,-36 493,-36 493,-36 487,-36 481,-30 481,-24 481,-24 481,-12 481,-12 481,-6 487,0 493,0 493,0 581,0 581,0 587,0 593,-6 593,-12 593,-12 593,-24 593,-24 593,-30 587,-36 581,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"537\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">ask_assistant</text>\n",
       "</g>\n",
       "<!-- router&#45;&gt;ask_assistant -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>router&#45;&gt;ask_assistant</title>\n",
       "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M548,-429.74C548,-413.74 548,-389.26 548,-368 548,-368 548,-368 548,-84 548,-71.47 545.99,-57.77 543.69,-46.19\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"547.07,-45.28 541.54,-36.25 540.23,-46.76 547.07,-45.28\"/>\n",
       "<text text-anchor=\"middle\" x=\"605.5\" y=\"-229.3\" font-family=\"Times,serif\" font-size=\"14.00\">route=assistant</text>\n",
       "</g>\n",
       "<!-- terminate -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>terminate</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M651.5,-385C651.5,-385 588.5,-385 588.5,-385 582.5,-385 576.5,-379 576.5,-373 576.5,-373 576.5,-361 576.5,-361 576.5,-355 582.5,-349 588.5,-349 588.5,-349 651.5,-349 651.5,-349 657.5,-349 663.5,-355 663.5,-361 663.5,-361 663.5,-373 663.5,-373 663.5,-379 657.5,-385 651.5,-385\"/>\n",
       "<text text-anchor=\"middle\" x=\"620\" y=\"-363.3\" font-family=\"Times,serif\" font-size=\"14.00\">terminate</text>\n",
       "</g>\n",
       "<!-- router&#45;&gt;terminate -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>router&#45;&gt;terminate</title>\n",
       "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M559.02,-429.7C565.26,-420.46 573.5,-409.15 582,-400 584.56,-397.25 587.34,-394.5 590.21,-391.85\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"592.7,-394.32 597.87,-385.07 588.06,-389.08 592.7,-394.32\"/>\n",
       "<text text-anchor=\"middle\" x=\"642.5\" y=\"-403.8\" font-family=\"Times,serif\" font-size=\"14.00\">route=terminate</text>\n",
       "</g>\n",
       "<!-- input__query -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>input__query</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"548\" cy=\"-515\" rx=\"67.69\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"548\" y=\"-511.3\" font-family=\"Times,serif\" font-size=\"14.00\">input: query</text>\n",
       "</g>\n",
       "<!-- input__query&#45;&gt;router -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>input__query&#45;&gt;router</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M548,-496.92C548,-490.7 548,-483.5 548,-476.6\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"551.5,-476.19 548,-466.19 544.5,-476.19 551.5,-476.19\"/>\n",
       "</g>\n",
       "<!-- search_lancedb -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>search_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M229,-318C229,-318 127,-318 127,-318 121,-318 115,-312 115,-306 115,-306 115,-294 115,-294 115,-288 121,-282 127,-282 127,-282 229,-282 229,-282 235,-282 241,-288 241,-294 241,-294 241,-306 241,-306 241,-312 235,-318 229,-318\"/>\n",
       "<text text-anchor=\"middle\" x=\"178\" y=\"-296.3\" font-family=\"Times,serif\" font-size=\"14.00\">search_lancedb</text>\n",
       "</g>\n",
       "<!-- rewrite_query_for_lancedb&#45;&gt;search_lancedb -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>rewrite_query_for_lancedb&#45;&gt;search_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M228.87,-348.92C221.17,-341.66 212.05,-333.07 203.67,-325.18\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"205.93,-322.5 196.25,-318.19 201.13,-327.59 205.93,-322.5\"/>\n",
       "</g>\n",
       "<!-- remove_irrelevant_lancedb_results -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>remove_irrelevant_lancedb_results</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M254,-251C254,-251 12,-251 12,-251 6,-251 0,-245 0,-239 0,-239 0,-227 0,-227 0,-221 6,-215 12,-215 12,-215 254,-215 254,-215 260,-215 266,-221 266,-227 266,-227 266,-239 266,-239 266,-245 260,-251 254,-251\"/>\n",
       "<text text-anchor=\"middle\" x=\"133\" y=\"-229.3\" font-family=\"Times,serif\" font-size=\"14.00\">remove_irrelevant_lancedb_results</text>\n",
       "</g>\n",
       "<!-- search_lancedb&#45;&gt;remove_irrelevant_lancedb_results -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>search_lancedb&#45;&gt;remove_irrelevant_lancedb_results</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M166.18,-281.92C161.46,-275.11 155.93,-267.12 150.75,-259.64\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"153.47,-257.42 144.9,-251.19 147.71,-261.4 153.47,-257.42\"/>\n",
       "</g>\n",
       "<!-- remove_irrelevant_lancedb_results&#45;&gt;extract_keywords_for_exa_search -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>remove_irrelevant_lancedb_results&#45;&gt;extract_keywords_for_exa_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" d=\"M128.33,-214.78C126.75,-204.88 126.9,-192.89 134,-185 142.45,-175.61 193.85,-168.21 248.05,-162.91\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"248.56,-166.37 258.19,-161.94 247.9,-159.4 248.56,-166.37\"/>\n",
       "<text text-anchor=\"middle\" x=\"253\" y=\"-188.8\" font-family=\"Times,serif\" font-size=\"14.00\">len(lancedb_results) &lt; docs_limit</text>\n",
       "</g>\n",
       "<!-- remove_irrelevant_lancedb_results&#45;&gt;ask_assistant -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>remove_irrelevant_lancedb_results&#45;&gt;ask_assistant</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M123.98,-214.74C120.39,-205.3 118.3,-193.78 124,-185 194.22,-76.87 268.63,-113.17 389,-67 415.61,-56.79 445.43,-46.84 471.29,-38.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"472.5,-41.97 480.99,-35.64 470.4,-35.29 472.5,-41.97\"/>\n",
       "</g>\n",
       "<!-- search_exa -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>search_exa</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M481.5,-103C481.5,-103 410.5,-103 410.5,-103 404.5,-103 398.5,-97 398.5,-91 398.5,-91 398.5,-79 398.5,-79 398.5,-73 404.5,-67 410.5,-67 410.5,-67 481.5,-67 481.5,-67 487.5,-67 493.5,-73 493.5,-79 493.5,-79 493.5,-91 493.5,-91 493.5,-97 487.5,-103 481.5,-103\"/>\n",
       "<text text-anchor=\"middle\" x=\"446\" y=\"-81.3\" font-family=\"Times,serif\" font-size=\"14.00\">search_exa</text>\n",
       "</g>\n",
       "<!-- extract_keywords_for_exa_search&#45;&gt;search_exa -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>extract_keywords_for_exa_search&#45;&gt;search_exa</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M401.76,-133.92C408.33,-126.81 416.07,-118.42 423.23,-110.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"425.92,-112.91 430.13,-103.19 420.78,-108.16 425.92,-112.91\"/>\n",
       "</g>\n",
       "<!-- search_exa&#45;&gt;ask_assistant -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>search_exa&#45;&gt;ask_assistant</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M469.91,-66.92C480.48,-59.37 493.08,-50.37 504.49,-42.22\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"506.83,-44.85 512.93,-36.19 502.76,-39.15 506.83,-44.85\"/>\n",
       "</g>\n",
       "<!-- ask_assistant&#45;&gt;router -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>ask_assistant&#45;&gt;router</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M593.46,-28.26C646.25,-38.56 717,-57.34 717,-84 717,-368 717,-368 717,-368 717,-389.8 719.06,-400.26 703,-415 686.66,-430 629.28,-438.8 589.13,-443.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"588.65,-439.8 579.07,-444.33 589.38,-446.76 588.65,-439.8\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f85adb3c4d0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "app = adaptive_crag_app(db=lance_db)\n",
    "app.visualize(\n",
    "    output_file_path=\"statemachine\",\n",
    "    include_conditions=True,\n",
    "    include_state=False,\n",
    "    format=\"png\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = {\"query\": get_user_query()}\n",
    "while True:\n",
    "    action, result, state = app.step(inputs=inputs)  # type:ignore\n",
    "    print(f\"\\nRESULT: {result}\\n\")\n",
    "    if action.name == \"terminate\":\n",
    "        break\n",
    "    elif action.name == \"ask_assistant\":\n",
    "        inputs = {\"query\": get_user_query()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
